Introduction to SHAP¶


Dr. Jose Maria Alonso-Moral

Centro Singular de Investigación en Tecnoloxías Intelixentes (CiTIUS)

Universidade de Santiago de Compostela

ETSE-USC, Campus Vida, Santiago de Compostela, Spain


Mr. Pablo Miguel Perez-Ferreiro

Centro Singular de Investigación en Tecnoloxías Intelixentes (CiTIUS)

Universidade de Santiago de Compostela

ETSE-USC, Campus Vida, Santiago de Compostela, Spain


27 January 2026

Explainable and Trustworthy AI



1. Introduction¶

This interactive tutorial includes supplementary material for the first TXAI in the Lab hands-on session (I1. "Introduction to SHAP") in the subject Explainable and Trustworthy AI (Master in Artificial Intelligence). The session is led by Jose M. Alonso-Moral and Pablo Miguel Perez-Ferreiro at USC, Samuel Suárez Marcote at UDC, and David Nicholas Olivieri Cecchi at UVigo.

2. Settings¶

In this section, we prepare the software needed to run the notebook. Please abstain from changing anything in this section unless you're prompted to by your teachers, as it may break the notebook's functionality.

In [2]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

import os

os.system("pip install simplenlg --quiet")
os.system("pip install tabulate --quiet")
os.system("pip install numba==0.59.1 --quiet")
os.system("pip install salib==1.3.3 --quiet")
os.system("pip install numpy==1.26.4 --quiet")
os.system("pip install shap==0.46.0 --quiet")
os.system("pip install interpret==0.5.0 --quiet")
Out[2]:
0
In [3]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

import warnings
warnings.filterwarnings('ignore')

# Loading the plotting tool (for plotting fuzzy sets and rules)
import matplotlib.pyplot as plt

# Loading pandas for its DataFrames and some data-management functions
import pandas as pd

# Loading seaborn for plotting
import seaborn as sns

# Loading numpy for utilities
import numpy as np

# Loading lib to deal with arff files
from scipy.io.arff import loadarff

# Loading sklearn and several of its modules: they will allow us to build and validate the models whose interpretability we will study
import sklearn
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import cross_validate
from sklearn.metrics import classification_report
from sklearn.tree import export_text

# Loading InterpretML, a library specialized in offering explainable by design models.
from interpret import set_visualize_provider
from interpret.provider import InlineProvider
set_visualize_provider(InlineProvider())
from interpret.glassbox import ExplainableBoostingClassifier
from interpret import show

# Loading a library to handle SHAP Values
import shap

# Loading Graphviz in order to plot Decision Trees
import graphviz
In [4]:
## this code is developed by Jose Maria Alonso-Moral
## Auxiliary functions for the notebook:

# Plots a Pareto front between two variables (x, y) for the models named in n,
# labeled (labelx, labely), with the x-axis' limits defined (minx, maxx)
def plot_pareto_front(x,y,n,labelx,labely,minx,maxx):
  plt.title("Pareto Front")
  plt.xlabel(labelx)
  plt.ylabel(labely)
  plt.axis([minx, maxx, 0, 1])
  c=["ro","bo","go","rs","bs","gs","r*","b*","g*","r+","b+","g+"]
  for m in n:
    m_idx= n.index(m)
    plt.plot(x[m_idx], y[m_idx], c[m_idx], label=m)

  plt.grid(True)
  plt.legend()
  plt.show()

# Extracts the length of SHAP-based explanations; used in this practical session as a surrogate interpretability metric
def get_shap_explanation_length(single_lower_triangular_interactions, indexes=None, th=0.9):
    # Calculate the cumulative sum and absolute value using numpy functions
    shap_cumsum = np.cumsum(np.abs(single_lower_triangular_interactions))
    # Normalize the cumulative sum
    normalised_shap_cumsum = shap_cumsum / shap_cumsum[-1]
    # Find the index of the first element that exceeds the threshold using a loop
    first_above_idx = 0
    for i, val in enumerate(normalised_shap_cumsum):
        if val > th:
            first_above_idx = i
            break
    # Calculate the SHAP explanation length
    shap_expl_length = first_above_idx
    return shap_expl_length
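As a quick sanity check, the threshold logic of the helper above can be exercised on toy contribution values. The sketch below reproduces the function (with the loop condensed into `np.argmax`, which behaves identically here, including the 0 fallback when nothing exceeds the threshold) so the cell is self-contained; the toy values are invented for illustration.

```python
import numpy as np

# Same logic as the helper above, reproduced so this cell is self-contained.
def get_shap_explanation_length(single_lower_triangular_interactions, th=0.9):
    shap_cumsum = np.cumsum(np.abs(single_lower_triangular_interactions))
    normalised = shap_cumsum / shap_cumsum[-1]
    # Index of the first element whose normalised cumulative mass exceeds th
    return int(np.argmax(normalised > th))

# Toy contributions: the normalised cumulative sums are
# [0.5, 0.8, 0.9, 0.95, 1.0], so the first value above 0.9 sits at index 3.
toy = [0.5, 0.3, 0.1, 0.05, 0.05]
print(get_shap_explanation_length(toy))  # → 3
```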

3. Interpretable systems: from the ground up¶

In this section, we will go through the process of building an interpretable system step-by-step, working with both interpretable-by-design and opaque models. In this sense, we will:

  • Import and pre-process a suitable dataset.
  • Build and validate several ML models for it.
  • Explore the properties of the built models, and perform a first attempt at interpreting those which are transparent through direct inspection.
  • Utilize SHAP Values as a post-hoc explainability tool that allows us to overcome some of the problems brought by direct inspection.
  • Perform a preliminary evaluation of the trade-offs between performance and explainability that the field of TXAI usually involves.

3.1. Loading data and preliminary exploration¶

We will be working with the Pima Indians Diabetes Dataset in the usual Weka arff format.

A classical example dataset for machine learning, it contains 768 instances, all describing adult females of Pima Indian heritage. Its intended use is binary classification: predicting the onset of diabetes from 8 features defined by the World Health Organization:

  • Number of times pregnant
  • Plasma glucose concentration
  • Diastolic blood pressure
  • Triceps skin fold thickness
  • Two hour serum insulin
  • Body mass index
  • Diabetes pedigree function
  • Age

We will begin by importing it from your workspace, and then we will pre-process the data and explore it.

In [4]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

# Importing the train and test sets from the arff files, and setting the class and attribute names.
file_train = 'testlib/PIMA/PIMA.train.0.arff'
file_test = 'testlib/PIMA/PIMA.test.0.arff'
with open(file_train) as f:
    train_data_value, attributes = loadarff(f)
with open(file_test) as f:
    test_data_value, attributes = loadarff(f)
pima_class_names= ['tested_negative','tested_positive']
pima_fnames = ["Number of times pregnant", "Plasma glucose concentration", "Diastolic blood_pressure", "Triceps skin fold thickness", "Two hour serum insulin", "Body mass index", "Diabetes pedigree function", "Age"]

# We prepare the train sets
class_names = np.array(pima_class_names)
feature_names = np.array(attributes.names())
df=pd.DataFrame(train_data_value)
df.columns = feature_names
target = df.pop('class')
target_onehot = pd.get_dummies(target)[b'2.0'] 
x_tr = df
y_tr = target_onehot
y_tr_num = [float(out_class) for out_class in y_tr] # we will need the numerical outputs to simulate regression

# And the test sets too
df_test=pd.DataFrame(test_data_value)
x_test = df_test
target_test = df_test.pop('class')
target_test_onehot = pd.get_dummies(target_test)[b'2.0']
y_test = target_test_onehot
y_test_num = [float(out_class) for out_class in y_test] # likewise

# Print some general information of the dataset
print(f'Class names for the PIMA Dataset:\n\t{", ".join(list(class_names))}\nFeature names for the PIMA Dataset:\n\t{", ".join(list(feature_names))}')
Class names for the PIMA Dataset:
	tested_negative, tested_positive
Feature names for the PIMA Dataset:
	Number_of_times_pregnant, Plasma_glucose_concentration, Diastolic_blood_pressure, Triceps_skin_fold_thickness, 2_Hour_serum_insulin, Body_mass_index, Diabetes_pedigree_function, Age, class
In [5]:
## this code is developed by Jose Maria Alonso-Moral

# Check how the tabular data looks
df.head(n=10)
Out[5]:
Number_of_times_pregnant Plasma_glucose_concentration Diastolic_blood_pressure Triceps_skin_fold_thickness 2_Hour_serum_insulin Body_mass_index Diabetes_pedigree_function Age
0 5.0 103.0 108.0 37.0 0.0 39.2 0.305 65.0
1 7.0 103.0 66.0 32.0 0.0 39.1 0.344 31.0
2 10.0 101.0 76.0 48.0 180.0 32.9 0.171 63.0
3 5.0 139.0 64.0 35.0 140.0 28.6 0.411 26.0
4 0.0 74.0 52.0 10.0 36.0 27.8 0.269 22.0
5 5.0 136.0 82.0 0.0 0.0 0.0 0.640 69.0
6 3.0 176.0 86.0 27.0 156.0 33.3 1.154 52.0
7 3.0 128.0 72.0 25.0 190.0 32.4 0.549 27.0
8 7.0 184.0 84.0 33.0 0.0 35.5 0.355 41.0
9 3.0 99.0 62.0 19.0 74.0 21.8 0.279 26.0
In [6]:
## this code is developed by Jose Maria Alonso-Moral

# Take a look at the output class distribution for the test set. Remember that 1.0 means tested_negative and 2.0 means tested_positive.
df_test=pd.DataFrame(test_data_value)
plt.figure(figsize=(12,6))
sns.countplot(x='class', data=df_test)
Out[6]:
<Axes: xlabel='class', ylabel='count'>

3.2. Model building¶

Now, we build some ML models over the dataset we just imported: two Decision Trees (which are interpretable-by-design) and a Random Forest (which is opaque). We will also perform a simple validation of their performance.

Some notes about the model creation:

  • We use the default parameters in most cases; changing them will alter the performance and, potentially, the interpretations you can extract from the trees. For the purposes of this practical session, we will not go into depth beyond showing the differences between an unlimited Decision Tree and one with a hard cap on its depth.
  • You will see below that we also train Regressor equivalents of all three models. This is done to show that the approach we follow in this section works for that setting too, but it is otherwise not correct: as mentioned in the previous section, this is a binary classification dataset, and even though such a problem can be approximated by a regression on the [0.0, 1.0] range, it is generally improper to use a dataset in a context it was not designed for. Always be mindful of the type of data you are working with!

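To make the second note concrete, a regressor trained on 0.0/1.0 targets can be turned back into a (crude) classifier by thresholding its continuous output at 0.5. The sketch below uses small synthetic data (not the PIMA sets loaded earlier) purely to illustrate the idea:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Synthetic binary problem (illustrative data, not the PIMA sets)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(float)  # numerical 0.0/1.0 targets

clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# Thresholding the regressor's continuous predictions at 0.5 yields labels
reg_labels = (reg.predict(X) >= 0.5).astype(float)
agreement = np.mean(reg_labels == clf.predict(X))
print(f"classifier/regressor agreement on training data: {agreement:.2f}")
```

The two models need not agree exactly, since they optimize different split criteria (Gini impurity vs. squared error), which is part of why treating classification as regression is improper.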
You can check the official documentation of the sklearn models we are going to use: DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor and RandomForestRegressor.

In [7]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

# We fit a Decision Tree, a Decision Tree of limited depth, and a Random Forest.
dtc = tree.DecisionTreeClassifier()
dtc.fit(x_tr, y_tr)

dtc5 = tree.DecisionTreeClassifier(max_depth=5)
dtc5.fit(x_tr, y_tr)

rfc = RandomForestClassifier(n_estimators=1000)
rfc.fit(x_tr, y_tr)

models = [dtc, dtc5, rfc]
models_names = ['TREE', 'TREE5', 'RF']

# Cross-validation for the models just fit above 
print("1) Cross-validation (over training data)")
scorings = ['accuracy', 'f1']  # For binary classification
nF= 5
for model, model_name in zip(models, models_names):
    cv_results = cross_validate(model, x_tr, y_tr, cv=nF, 
                            scoring=scorings,
                            return_train_score=False)
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate [Average (St. Dev)] = {np.mean(cv_results["test_accuracy"]):.3f} ({np.std(cv_results["test_accuracy"]):.3f})')
    print(f'\t\tF-Score [Average (St. Dev)] = {np.mean(cv_results["test_f1"]):.3f} ({np.std(cv_results["test_f1"]):.3f})')

# Test with unknown instances
print("\n\n2) Test (with previously unseen data)")
models_acc=[]
target_names = ['class 1', 'class 2']
for model, model_name in zip(models, models_names):
    sc = round(model.score(x_test, y_test), 3) # we round the accuracy to 3 decimals and append it to the list of model accuracies for the test set
    models_acc.append(sc)
    # Mean accuracy of self.predict(x_test) wrt y_test
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate: {models_acc[-1]:.3f}')
    y_pred = model.predict(x_test)
    print(classification_report(y_test, y_pred, target_names=target_names))
1) Cross-validation (over training data)

	TREE:
		Correct Classification Rate [Average (St. Dev)] = 0.690 (0.034)
		F-Score [Average (St. Dev)] = 0.550 (0.031)

	TREE5:
		Correct Classification Rate [Average (St. Dev)] = 0.748 (0.024)
		F-Score [Average (St. Dev)] = 0.618 (0.026)

	RF:
		Correct Classification Rate [Average (St. Dev)] = 0.758 (0.031)
		F-Score [Average (St. Dev)] = 0.625 (0.057)


2) Test (with previously unseen data)

	TREE:
		Correct Classification Rate: 0.688
              precision    recall  f1-score   support

     class 1       0.78      0.72      0.75        50
     class 2       0.55      0.63      0.59        27

    accuracy                           0.69        77
   macro avg       0.67      0.67      0.67        77
weighted avg       0.70      0.69      0.69        77


	TREE5:
		Correct Classification Rate: 0.727
              precision    recall  f1-score   support

     class 1       0.78      0.80      0.79        50
     class 2       0.62      0.59      0.60        27

    accuracy                           0.73        77
   macro avg       0.70      0.70      0.70        77
weighted avg       0.73      0.73      0.73        77


	RF:
		Correct Classification Rate: 0.805
              precision    recall  f1-score   support

     class 1       0.89      0.80      0.84        50
     class 2       0.69      0.81      0.75        27

    accuracy                           0.81        77
   macro avg       0.79      0.81      0.79        77
weighted avg       0.82      0.81      0.81        77


In [8]:
## this code is developed by Pablo Miguel Perez-Ferreiro

# Keep in mind we need to use the numerical version of the output for regression equivalents
dtr = tree.DecisionTreeRegressor()
dtr.fit(x_tr, y_tr_num)

dtr5 = tree.DecisionTreeRegressor(max_depth=5)
dtr5.fit(x_tr, y_tr_num)

rfr = RandomForestRegressor(n_estimators=1000)
rfr.fit(x_tr, y_tr_num)

models_r = [dtr, dtr5, rfr]
models_r_names = ['TREE-R', 'TREE5-R', 'RF-R']

# We skip validation for the Regressors; it could be done, but regression metrics do not fit this classification problem.

3.3. Visualization¶

An advantage of the interpretable models we just trained (i.e., the Decision Trees) is that they can be easily visualized. When DTs are reasonably shallow, this may suffice to explain:

  • Their global behavior, as the general structure of the tree provides a certain understanding of the tree's priorities when it comes to classifying instances. In this sense, a Decision Tree always gauges the importance of its features (implicitly): consider that most algorithms for decision tree creation build the decision splits based on information gain measures, which means that more informative features will be placed higher on the tree's branches.
  • Their local behavior, as any given instance will follow a certain path until reaching a leaf node. This path of conditions can be understood as an ad-hoc rule that justifies the instance's classification.

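The "path as an ad-hoc rule" idea above can be sketched with sklearn's `decision_path`, which returns the nodes an instance visits; reading off each node's split condition yields the rule that justifies the prediction. The sketch below uses the iris dataset as a stand-in so the cell is self-contained:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Illustrative data (iris), not the PIMA sets used in this notebook
data = load_iris()
X, y, fnames = data.data, data.target, data.feature_names
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

# Nodes visited by one instance, read off the sparse indicator matrix
sample = 0
node_index = clf.decision_path(X[[sample]]).indices
tree_ = clf.tree_
for node in node_index:
    if tree_.children_left[node] == tree_.children_right[node]:
        # Leaf node: no condition, just the final prediction
        print(f"leaf {node}: predicted class {clf.predict(X[[sample]])[0]}")
    else:
        # Internal node: print the condition this instance satisfied
        f, t = tree_.feature[node], tree_.threshold[node]
        op = "<=" if X[sample, f] <= t else ">"
        print(f"node {node}: {fnames[f]} {op} {t:.2f}")
```

Concatenating the printed conditions with AND gives the local rule for that instance.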
These two properties require, however, that the tree's features hold a degree of meaning: highly complex features may not be easily understood, which limits their usefulness as an explanatory tool. Plain visualization can also be unwieldy when DTs are very deep, nesting many splits. You can check that effect by comparing the graphical representations below (unlimited depth vs. limited depth); in them, each node lists:

  • The split condition. Leaf nodes omit this, as there are no more conditions to check.
  • The Gini impurity, which measures how mixed the classes at that node are and is the criterion that guides the splits. It can be swapped for entropy during model creation; check the documentation previously linked.
  • The number of samples that have 'reached' that node.
  • The distribution of classes amongst those samples: [A, B] represents [tested_negative, tested_positive].
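The implicit feature-importance ranking mentioned above is also exposed directly by sklearn as `feature_importances_` (impurity-based importances): features with higher values tend to sit nearer the root. A minimal sketch on the breast cancer dataset (an illustrative stand-in, not the PIMA sets):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

# Illustrative dataset; the same attribute works for the trees fit above
data = load_breast_cancer()
X, y, names = data.data, data.target, data.feature_names
clf = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X, y)

# Impurity-based (Gini) importances, normalized to sum to 1
order = np.argsort(clf.feature_importances_)[::-1]
for i in order[:3]:
    print(f"{names[i]}: {clf.feature_importances_[i]:.3f}")
```

These importances are a global summary only; they say nothing about any individual prediction, which is where SHAP values come in later.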
In [9]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Visualizing entire decision tree
print(export_text(dtc, feature_names=pima_fnames))
dot_data = tree.export_graphviz(dtc, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames) 
graph = graphviz.Source(dot_data) 
graph#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- Plasma glucose concentration <= 143.50
|   |--- Body mass index <= 27.35
|   |   |--- Plasma glucose concentration <= 105.50
|   |   |   |--- class: False
|   |   |--- Plasma glucose concentration >  105.50
|   |   |   |--- Body mass index <= 9.80
|   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Number of times pregnant >  7.00
|   |   |   |   |   |--- class: True
|   |   |   |--- Body mass index >  9.80
|   |   |   |   |--- Plasma glucose concentration <= 108.50
|   |   |   |   |   |--- Age <= 50.00
|   |   |   |   |   |   |--- Diastolic blood_pressure <= 65.00
|   |   |   |   |   |   |   |--- Two hour serum insulin <= 52.00
|   |   |   |   |   |   |   |   |--- Number of times pregnant <= 1.50
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Number of times pregnant >  1.50
|   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Two hour serum insulin >  52.00
|   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |--- Diastolic blood_pressure >  65.00
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |--- Age >  50.00
|   |   |   |   |   |   |--- class: True
|   |   |   |   |--- Plasma glucose concentration >  108.50
|   |   |   |   |   |--- Plasma glucose concentration <= 128.50
|   |   |   |   |   |   |--- class: False
|   |   |   |   |   |--- Plasma glucose concentration >  128.50
|   |   |   |   |   |   |--- Number of times pregnant <= 2.50
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |--- Number of times pregnant >  2.50
|   |   |   |   |   |   |   |--- Number of times pregnant <= 4.50
|   |   |   |   |   |   |   |   |--- Body mass index <= 23.45
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Body mass index >  23.45
|   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Number of times pregnant >  4.50
|   |   |   |   |   |   |   |   |--- class: False
|   |--- Body mass index >  27.35
|   |   |--- Plasma glucose concentration <= 99.50
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- Diastolic blood_pressure <= 83.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Diastolic blood_pressure >  83.50
|   |   |   |   |   |--- Diabetes pedigree function <= 0.28
|   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Diabetes pedigree function >  0.28
|   |   |   |   |   |   |--- class: False
|   |   |   |--- Age >  25.50
|   |   |   |   |--- Age <= 27.50
|   |   |   |   |   |--- Two hour serum insulin <= 35.00
|   |   |   |   |   |   |--- Triceps skin fold thickness <= 37.50
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |--- Triceps skin fold thickness >  37.50
|   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Two hour serum insulin >  35.00
|   |   |   |   |   |   |--- class: True
|   |   |   |   |--- Age >  27.50
|   |   |   |   |   |--- Plasma glucose concentration <= 28.50
|   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Plasma glucose concentration >  28.50
|   |   |   |   |   |   |--- Diabetes pedigree function <= 0.80
|   |   |   |   |   |   |   |--- Age <= 42.50
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.17
|   |   |   |   |   |   |   |   |   |--- Triceps skin fold thickness <= 18.50
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |--- Triceps skin fold thickness >  18.50
|   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.17
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |--- Age >  42.50
|   |   |   |   |   |   |   |   |--- Triceps skin fold thickness <= 24.00
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Triceps skin fold thickness >  24.00
|   |   |   |   |   |   |   |   |   |--- Triceps skin fold thickness <= 31.50
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |--- Triceps skin fold thickness >  31.50
|   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |--- Diabetes pedigree function >  0.80
|   |   |   |   |   |   |   |--- Number of times pregnant <= 3.00
|   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |--- Number of times pregnant >  3.00
|   |   |   |   |   |   |   |   |--- class: True
|   |   |--- Plasma glucose concentration >  99.50
|   |   |   |--- Age <= 30.50
|   |   |   |   |--- Diastolic blood_pressure <= 22.00
|   |   |   |   |   |--- Plasma glucose concentration <= 117.00
|   |   |   |   |   |   |--- class: False
|   |   |   |   |   |--- Plasma glucose concentration >  117.00
|   |   |   |   |   |   |--- class: True
|   |   |   |   |--- Diastolic blood_pressure >  22.00
|   |   |   |   |   |--- Body mass index <= 45.40
|   |   |   |   |   |   |--- Diabetes pedigree function <= 0.51
|   |   |   |   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |   |   |   |--- Plasma glucose concentration <= 138.50
|   |   |   |   |   |   |   |   |   |--- Two hour serum insulin <= 22.00
|   |   |   |   |   |   |   |   |   |   |--- Body mass index <= 27.70
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |   |--- Body mass index >  27.70
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |   |   |   |   |--- Two hour serum insulin >  22.00
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration <= 133.00
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration >  133.00
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |   |   |   |--- Plasma glucose concentration >  138.50
|   |   |   |   |   |   |   |   |   |--- Two hour serum insulin <= 65.00
|   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |--- Two hour serum insulin >  65.00
|   |   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure <= 64.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure >  64.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Number of times pregnant >  7.00
|   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Diabetes pedigree function >  0.51
|   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.55
|   |   |   |   |   |   |   |   |--- Body mass index <= 38.25
|   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.53
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration <= 110.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration >  110.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.53
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |--- Body mass index >  38.25
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.55
|   |   |   |   |   |   |   |   |--- Body mass index <= 32.70
|   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 1.07
|   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  1.07
|   |   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 1.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  1.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Body mass index >  32.70
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure <= 69.00
|   |   |   |   |   |   |   |   |   |   |--- Body mass index <= 34.85
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |   |--- Body mass index >  34.85
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure >  69.00
|   |   |   |   |   |   |   |   |   |   |--- Age <= 21.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |   |--- Age >  21.50
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- Body mass index >  45.40
|   |   |   |   |   |   |--- Age <= 22.50
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |--- Age >  22.50
|   |   |   |   |   |   |   |--- class: True
|   |   |   |--- Age >  30.50
|   |   |   |   |--- Diabetes pedigree function <= 0.53
|   |   |   |   |   |--- Age <= 47.50
|   |   |   |   |   |   |--- Diabetes pedigree function <= 0.13
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |--- Diabetes pedigree function >  0.13
|   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.49
|   |   |   |   |   |   |   |   |--- Body mass index <= 31.15
|   |   |   |   |   |   |   |   |   |--- Number of times pregnant <= 6.50
|   |   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.41
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.41
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |--- Number of times pregnant >  6.50
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |--- Body mass index >  31.15
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure <= 97.00
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration <= 136.50
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 9
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration >  136.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure >  97.00
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.49
|   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |--- Age >  47.50
|   |   |   |   |   |   |--- Number of times pregnant <= 1.50
|   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Number of times pregnant >  1.50
|   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.14
|   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.14
|   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |--- Diabetes pedigree function >  0.53
|   |   |   |   |   |--- Number of times pregnant <= 7.50
|   |   |   |   |   |   |--- Diastolic blood_pressure <= 69.00
|   |   |   |   |   |   |   |--- Body mass index <= 28.10
|   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Body mass index >  28.10
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 1.84
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  1.84
|   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Diastolic blood_pressure >  69.00
|   |   |   |   |   |   |   |--- Body mass index <= 39.65
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.56
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure <= 72.50
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure >  72.50
|   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.56
|   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Body mass index >  39.65
|   |   |   |   |   |   |   |   |--- Two hour serum insulin <= 146.00
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Two hour serum insulin >  146.00
|   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Number of times pregnant >  7.50
|   |   |   |   |   |   |--- class: True
|--- Plasma glucose concentration >  143.50
|   |--- Plasma glucose concentration <= 154.50
|   |   |--- Diabetes pedigree function <= 0.33
|   |   |   |--- Diabetes pedigree function <= 0.18
|   |   |   |   |--- Plasma glucose concentration <= 151.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Plasma glucose concentration >  151.00
|   |   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function >  0.18
|   |   |   |   |--- Age <= 67.50
|   |   |   |   |   |--- Number of times pregnant <= 6.50
|   |   |   |   |   |   |--- class: False
|   |   |   |   |   |--- Number of times pregnant >  6.50
|   |   |   |   |   |   |--- Diastolic blood_pressure <= 79.00
|   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Diastolic blood_pressure >  79.00
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |--- Age >  67.50
|   |   |   |   |   |--- class: True
|   |   |--- Diabetes pedigree function >  0.33
|   |   |   |--- Age <= 31.50
|   |   |   |   |--- Diabetes pedigree function <= 0.37
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diabetes pedigree function >  0.37
|   |   |   |   |   |--- Plasma glucose concentration <= 145.00
|   |   |   |   |   |   |--- Triceps skin fold thickness <= 36.50
|   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Triceps skin fold thickness >  36.50
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |--- Plasma glucose concentration >  145.00
|   |   |   |   |   |   |--- class: False
|   |   |   |--- Age >  31.50
|   |   |   |   |--- Plasma glucose concentration <= 152.50
|   |   |   |   |   |--- Two hour serum insulin <= 313.50
|   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Two hour serum insulin >  313.50
|   |   |   |   |   |   |--- class: False
|   |   |   |   |--- Plasma glucose concentration >  152.50
|   |   |   |   |   |--- Two hour serum insulin <= 63.00
|   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Two hour serum insulin >  63.00
|   |   |   |   |   |   |--- class: False
|   |--- Plasma glucose concentration >  154.50
|   |   |--- Body mass index <= 29.85
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- class: False
|   |   |   |--- Age >  25.50
|   |   |   |   |--- Age <= 61.00
|   |   |   |   |   |--- Body mass index <= 27.00
|   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Body mass index >  27.00
|   |   |   |   |   |   |--- Age <= 36.50
|   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Age >  36.50
|   |   |   |   |   |   |   |--- class: False
|   |   |   |   |--- Age >  61.00
|   |   |   |   |   |--- class: False
|   |   |--- Body mass index >  29.85
|   |   |   |--- Diabetes pedigree function <= 0.13
|   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function >  0.13
|   |   |   |   |--- Age <= 44.00
|   |   |   |   |   |--- Two hour serum insulin <= 661.50
|   |   |   |   |   |   |--- Number of times pregnant <= 9.50
|   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.31
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function <= 0.29
|   |   |   |   |   |   |   |   |   |--- Number of times pregnant <= 0.50
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration <= 173.00
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration >  173.00
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |--- Number of times pregnant >  0.50
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.29
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |--- Diabetes pedigree function >  0.31
|   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Number of times pregnant >  9.50
|   |   |   |   |   |   |   |--- Body mass index <= 34.35
|   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Body mass index >  34.35
|   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |--- Two hour serum insulin >  661.50
|   |   |   |   |   |   |--- class: False
|   |   |   |   |--- Age >  44.00
|   |   |   |   |   |--- Diabetes pedigree function <= 1.16
|   |   |   |   |   |   |--- Body mass index <= 34.45
|   |   |   |   |   |   |   |--- Body mass index <= 32.20
|   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |--- Body mass index >  32.20
|   |   |   |   |   |   |   |   |--- Two hour serum insulin <= 111.50
|   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |--- Two hour serum insulin >  111.50
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure <= 65.00
|   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |   |   |   |--- Diastolic blood_pressure >  65.00
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration <= 170.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: False
|   |   |   |   |   |   |   |   |   |   |--- Plasma glucose concentration >  170.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Body mass index >  34.45
|   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |--- Diabetes pedigree function >  1.16
|   |   |   |   |   |   |--- Diastolic blood_pressure <= 77.00
|   |   |   |   |   |   |   |--- class: True
|   |   |   |   |   |   |--- Diastolic blood_pressure >  77.00
|   |   |   |   |   |   |   |--- class: False

Out[9]:
[Graphviz tree visualization]

InĀ [10]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Visualizing the reduced tree
print(export_text(dtc5, feature_names=pima_fnames))
dot_data5 = tree.export_graphviz(dtc5, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames) 
graph5 = graphviz.Source(dot_data5) 
graph5#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- Plasma glucose concentration <= 143.50
|   |--- Body mass index <= 27.35
|   |   |--- Plasma glucose concentration <= 105.50
|   |   |   |--- class: False
|   |   |--- Plasma glucose concentration >  105.50
|   |   |   |--- Body mass index <= 9.80
|   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Number of times pregnant >  7.00
|   |   |   |   |   |--- class: True
|   |   |   |--- Body mass index >  9.80
|   |   |   |   |--- Plasma glucose concentration <= 108.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Plasma glucose concentration >  108.50
|   |   |   |   |   |--- class: False
|   |--- Body mass index >  27.35
|   |   |--- Plasma glucose concentration <= 99.50
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- Diastolic blood_pressure <= 83.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Diastolic blood_pressure >  83.50
|   |   |   |   |   |--- class: False
|   |   |   |--- Age >  25.50
|   |   |   |   |--- Age <= 27.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age >  27.50
|   |   |   |   |   |--- class: False
|   |   |--- Plasma glucose concentration >  99.50
|   |   |   |--- Age <= 30.50
|   |   |   |   |--- Diastolic blood_pressure <= 22.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diastolic blood_pressure >  22.00
|   |   |   |   |   |--- class: False
|   |   |   |--- Age >  30.50
|   |   |   |   |--- Diabetes pedigree function <= 0.53
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Diabetes pedigree function >  0.53
|   |   |   |   |   |--- class: True
|--- Plasma glucose concentration >  143.50
|   |--- Plasma glucose concentration <= 154.50
|   |   |--- Diabetes pedigree function <= 0.33
|   |   |   |--- Diabetes pedigree function <= 0.18
|   |   |   |   |--- Triceps skin fold thickness <= 28.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Triceps skin fold thickness >  28.50
|   |   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function >  0.18
|   |   |   |   |--- Age <= 67.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Age >  67.50
|   |   |   |   |   |--- class: True
|   |   |--- Diabetes pedigree function >  0.33
|   |   |   |--- Age <= 31.50
|   |   |   |   |--- Diabetes pedigree function <= 0.37
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diabetes pedigree function >  0.37
|   |   |   |   |   |--- class: False
|   |   |   |--- Age >  31.50
|   |   |   |   |--- Plasma glucose concentration <= 152.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Plasma glucose concentration >  152.50
|   |   |   |   |   |--- class: False
|   |--- Plasma glucose concentration >  154.50
|   |   |--- Body mass index <= 29.85
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- class: False
|   |   |   |--- Age >  25.50
|   |   |   |   |--- Age <= 61.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age >  61.00
|   |   |   |   |   |--- class: False
|   |   |--- Body mass index >  29.85
|   |   |   |--- Diabetes pedigree function <= 0.13
|   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function >  0.13
|   |   |   |   |--- Age <= 44.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age >  44.00
|   |   |   |   |   |--- class: True

Out[10]:
[Graphviz rendering of the reduced decision tree]

We can also plot the regressor trees. You'll see their format is mostly equivalent to the one above, although they use squared error instead of the Gini index (regression vs. classification, as noted earlier), and each node reports the value that would be assigned to a data instance that ended up there.
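To make that last point concrete: the value reported at a regression leaf is simply the mean target of the training instances routed to it. Below is a minimal, self-contained sketch on synthetic data (not the Pima dataset) illustrating this:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Tiny synthetic problem: one feature, numeric target
X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0], [12.0]])
y = np.array([0.0, 0.0, 1.0, 1.0, 1.0, 1.0])

reg = DecisionTreeRegressor(max_depth=1, random_state=0)
reg.fit(X, y)

# Each leaf's value is the mean of the targets that reach it
print(export_text(reg, feature_names=["x"]))
```

With a depth-1 tree, the single squared-error-optimal split here is at x <= 2.50, and the two leaves report the group means 0.00 and 1.00, in the same `value: [...]` format as the regressor trees above.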

InĀ [11]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro
# Visualizing the reduced regressor tree
print(export_text(dtr5, feature_names=pima_fnames))
dot_data5r = tree.export_graphviz(dtr5, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames) 
graph5r = graphviz.Source(dot_data5r) 
graph5r#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- Plasma glucose concentration <= 143.50
|   |--- Body mass index <= 27.35
|   |   |--- Plasma glucose concentration <= 105.50
|   |   |   |--- value: [0.00]
|   |   |--- Plasma glucose concentration >  105.50
|   |   |   |--- Body mass index <= 9.80
|   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |--- value: [0.00]
|   |   |   |   |--- Number of times pregnant >  7.00
|   |   |   |   |   |--- value: [1.00]
|   |   |   |--- Body mass index >  9.80
|   |   |   |   |--- Plasma glucose concentration <= 108.50
|   |   |   |   |   |--- value: [0.23]
|   |   |   |   |--- Plasma glucose concentration >  108.50
|   |   |   |   |   |--- value: [0.05]
|   |--- Body mass index >  27.35
|   |   |--- Plasma glucose concentration <= 99.50
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- Diastolic blood_pressure <= 83.50
|   |   |   |   |   |--- value: [0.00]
|   |   |   |   |--- Diastolic blood_pressure >  83.50
|   |   |   |   |   |--- value: [0.33]
|   |   |   |--- Age >  25.50
|   |   |   |   |--- Age <= 27.50
|   |   |   |   |   |--- value: [0.67]
|   |   |   |   |--- Age >  27.50
|   |   |   |   |   |--- value: [0.17]
|   |   |--- Plasma glucose concentration >  99.50
|   |   |   |--- Age <= 30.50
|   |   |   |   |--- Diastolic blood_pressure <= 22.00
|   |   |   |   |   |--- value: [0.86]
|   |   |   |   |--- Diastolic blood_pressure >  22.00
|   |   |   |   |   |--- value: [0.25]
|   |   |   |--- Age >  30.50
|   |   |   |   |--- Diabetes pedigree function <= 0.53
|   |   |   |   |   |--- value: [0.45]
|   |   |   |   |--- Diabetes pedigree function >  0.53
|   |   |   |   |   |--- value: [0.80]
|--- Plasma glucose concentration >  143.50
|   |--- Plasma glucose concentration <= 154.50
|   |   |--- Diabetes pedigree function <= 0.33
|   |   |   |--- Diabetes pedigree function <= 0.18
|   |   |   |   |--- Plasma glucose concentration <= 151.00
|   |   |   |   |   |--- value: [1.00]
|   |   |   |   |--- Plasma glucose concentration >  151.00
|   |   |   |   |   |--- value: [0.00]
|   |   |   |--- Diabetes pedigree function >  0.18
|   |   |   |   |--- Age <= 67.50
|   |   |   |   |   |--- value: [0.08]
|   |   |   |   |--- Age >  67.50
|   |   |   |   |   |--- value: [1.00]
|   |   |--- Diabetes pedigree function >  0.33
|   |   |   |--- Age <= 31.50
|   |   |   |   |--- Diabetes pedigree function <= 0.37
|   |   |   |   |   |--- value: [1.00]
|   |   |   |   |--- Diabetes pedigree function >  0.37
|   |   |   |   |   |--- value: [0.12]
|   |   |   |--- Age >  31.50
|   |   |   |   |--- Plasma glucose concentration <= 152.50
|   |   |   |   |   |--- value: [0.94]
|   |   |   |   |--- Plasma glucose concentration >  152.50
|   |   |   |   |   |--- value: [0.33]
|   |--- Plasma glucose concentration >  154.50
|   |   |--- Body mass index <= 29.85
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- value: [0.00]
|   |   |   |--- Age >  25.50
|   |   |   |   |--- Age <= 61.00
|   |   |   |   |   |--- value: [0.80]
|   |   |   |   |--- Age >  61.00
|   |   |   |   |   |--- value: [0.00]
|   |   |--- Body mass index >  29.85
|   |   |   |--- Diabetes pedigree function <= 0.13
|   |   |   |   |--- value: [0.00]
|   |   |   |--- Diabetes pedigree function >  0.13
|   |   |   |   |--- Age <= 44.00
|   |   |   |   |   |--- value: [0.93]
|   |   |   |   |--- Age >  44.00
|   |   |   |   |   |--- value: [0.76]

Out[11]:
[Graphviz rendering of the reduced regression tree]

Be aware that, convenient as it may be, this type of interpretation can't be performed for all (or even most) ML approaches. For instance, the Random Forest we just trained can't be visualized like this and, in fact, is not interpretable by itself. However, Random Forests are usually very strong in terms of predictive performance, so ideally we would like to make them at least somewhat understandable. We will need a different tool for this.
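For completeness, note that scikit-learn ensembles do expose a coarse built-in summary, `feature_importances_` (impurity-based importances). It yields a single global ranking, but no per-instance attributions and a known bias toward high-cardinality features, which is part of why we turn to SHAP later. A minimal sketch on synthetic data (none of these names come from the notebook):

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
# The label depends only on the first feature
y = (X[:, 0] > 0).astype(int)

rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)
importances = rf.feature_importances_
print(importances)  # importances sum to 1; feature 0 should dominate
```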


Exercise I1.1¶

Try to build some alternative versions of the Decision Trees we obtained by changing the parameters in the constructor call; refer to this documentation. You don't need to do this for the Regressors. You don't need to be exhaustive: trying two or three is enough. You may keep limiting their depth to keep their interpretation manageable (but you don't need to).

Plot these new decision trees and compare them to the two trees (dtc and dtc5) given as illustrative examples in the previous cells. Are they significantly different? Extract some conclusions as to why (or why not), and summarize a few key insights into the diabetes prediction problem that you can distill from the trends shown by your set of trees.

InĀ [Ā ]:
# this code is developed by Miguel Leal and Gian Paolo Bulleddu
#==========================================================
# We will now fit some alternative Decision Tree models with different hyperparameters,
# so that we can check how the interpretability and performance metrics change.
# We fit one Decision Tree with the same maximum depth as before but using the
# 'entropy' criterion instead of 'gini', and another with the same maximum depth
# but with a minimum of 3 samples per leaf, which reduces the complexity of the
# tree and increases its interpretability.
#==========================================================
dtcv1 = tree.DecisionTreeClassifier(max_depth=5, criterion='entropy')
dtcv1.fit(x_tr, y_tr)

dtcv2 = tree.DecisionTreeClassifier(max_depth=5, min_samples_leaf=3)
dtcv2.fit(x_tr, y_tr)

models_alt = [dtcv1, dtcv2]
models_names_alt = ['TREEv1', 'TREEv2']

# Cross-validation for the models just fit above 
print("1) Cross-validation (over training data)")
scorings = ['accuracy', 'f1']  # For binary classification
nF= 5
for model, model_name in zip(models_alt, models_names_alt):
    cv_results = cross_validate(model, x_tr, y_tr, cv=nF, 
                            scoring=scorings,
                            return_train_score=False)
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate [Average (St. Dev)] = {np.mean(cv_results["test_accuracy"]):.3f} ({np.std(cv_results["test_accuracy"]):.3f})')
    print(f'\t\tF-Score [Average (St. Dev)] = {np.mean(cv_results["test_f1"]):.3f} ({np.std(cv_results["test_f1"]):.3f})')

# Test with unknown instances
print("\n\n2) Test (with previously unseen data)")
models_acc_alt=[]
target_names = ['class 1', 'class 2']
for model, model_name in zip(models_alt, models_names_alt):
    sc = round(model.score(x_test, y_test), 3) # we round the accuracy to 3 decimals and append it to the list of model accuracies for the test set
    models_acc_alt.append(sc)
    # Mean accuracy of self.predict(x_test) wrt y_test
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate: {models_acc_alt[-1]:.3f}')
    y_pred = model.predict(x_test)
    print(classification_report(y_test, y_pred, target_names=target_names))
1) Cross-validation (over training data)

	TREEv1:
		Correct Classification Rate [Average (St. Dev)] = 0.729 (0.051)
		F-Score [Average (St. Dev)] = 0.576 (0.087)

	TREEv2:
		Correct Classification Rate [Average (St. Dev)] = 0.737 (0.029)
		F-Score [Average (St. Dev)] = 0.591 (0.046)


2) Test (with previously unseen data)

	TREEv1:
		Correct Classification Rate: 0.805
              precision    recall  f1-score   support

     class 1       0.86      0.84      0.85        50
     class 2       0.71      0.74      0.73        27

    accuracy                           0.81        77
   macro avg       0.79      0.79      0.79        77
weighted avg       0.81      0.81      0.81        77


	TREEv2:
		Correct Classification Rate: 0.727
              precision    recall  f1-score   support

     class 1       0.78      0.80      0.79        50
     class 2       0.62      0.59      0.60        27

    accuracy                           0.73        77
   macro avg       0.70      0.70      0.70        77
weighted avg       0.73      0.73      0.73        77

InĀ [Ā ]:
# this code is developed by Miguel Leal and Gian Paolo Bulleddu
# Visualizing the alternative reduced tree with 'entropy' criterion
print(export_text(dtcv1, feature_names=pima_fnames))
dot_datav1 = tree.export_graphviz(dtcv1, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames) 
graphv1 = graphviz.Source(dot_datav1) 
graphv1#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- Plasma glucose concentration <= 127.50
|   |--- Body mass index <= 26.45
|   |   |--- Body mass index <= 9.10
|   |   |   |--- Number of times pregnant <= 7.50
|   |   |   |   |--- class: False
|   |   |   |--- Number of times pregnant >  7.50
|   |   |   |   |--- class: True
|   |   |--- Body mass index >  9.10
|   |   |   |--- Diabetes pedigree function <= 0.68
|   |   |   |   |--- class: False
|   |   |   |--- Diabetes pedigree function >  0.68
|   |   |   |   |--- Diabetes pedigree function <= 0.71
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diabetes pedigree function >  0.71
|   |   |   |   |   |--- class: False
|   |--- Body mass index >  26.45
|   |   |--- Age <= 28.50
|   |   |   |--- Body mass index <= 30.95
|   |   |   |   |--- Number of times pregnant <= 7.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Number of times pregnant >  7.00
|   |   |   |   |   |--- class: True
|   |   |   |--- Body mass index >  30.95
|   |   |   |   |--- Diastolic blood_pressure <= 37.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Diastolic blood_pressure >  37.00
|   |   |   |   |   |--- class: False
|   |   |--- Age >  28.50
|   |   |   |--- Plasma glucose concentration <= 99.50
|   |   |   |   |--- Plasma glucose concentration <= 28.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Plasma glucose concentration >  28.50
|   |   |   |   |   |--- class: False
|   |   |   |--- Plasma glucose concentration >  99.50
|   |   |   |   |--- Diabetes pedigree function <= 0.56
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Diabetes pedigree function >  0.56
|   |   |   |   |   |--- class: True
|--- Plasma glucose concentration >  127.50
|   |--- Plasma glucose concentration <= 154.50
|   |   |--- Body mass index <= 28.85
|   |   |   |--- Number of times pregnant <= 1.50
|   |   |   |   |--- class: False
|   |   |   |--- Number of times pregnant >  1.50
|   |   |   |   |--- Body mass index <= 23.45
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Body mass index >  23.45
|   |   |   |   |   |--- class: False
|   |   |--- Body mass index >  28.85
|   |   |   |--- Diabetes pedigree function <= 0.44
|   |   |   |   |--- Body mass index <= 41.80
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Body mass index >  41.80
|   |   |   |   |   |--- class: True
|   |   |   |--- Diabetes pedigree function >  0.44
|   |   |   |   |--- Age <= 30.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- Age >  30.00
|   |   |   |   |   |--- class: True
|   |--- Plasma glucose concentration >  154.50
|   |   |--- Body mass index <= 29.85
|   |   |   |--- Age <= 25.50
|   |   |   |   |--- class: False
|   |   |   |--- Age >  25.50
|   |   |   |   |--- Age <= 61.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Age >  61.00
|   |   |   |   |   |--- class: False
|   |   |--- Body mass index >  29.85
|   |   |   |--- Diastolic blood_pressure <= 67.00
|   |   |   |   |--- class: True
|   |   |   |--- Diastolic blood_pressure >  67.00
|   |   |   |   |--- Two hour serum insulin <= 661.50
|   |   |   |   |   |--- class: True
|   |   |   |   |--- Two hour serum insulin >  661.50
|   |   |   |   |   |--- class: False

Out[Ā ]:
[Graphviz rendering of the entropy-based tree]
InĀ [Ā ]:
# this code is developed by Miguel Leal and Gian Paolo Bulleddu
# Visualizing the alternative reduced tree with min_samples_leaf=3
print(export_text(dtcv2, feature_names=pima_fnames))
dot_datav2 = tree.export_graphviz(dtcv2, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=pima_fnames) 
graphv2 = graphviz.Source(dot_datav2) 
graphv2#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.

Explanation:

In this exercise we built two decision tree classifiers by varying some hyperparameters. In both models we kept the maximum depth at 5; in the first model we changed the Gini impurity criterion to entropy, while in the second model we set the minimum number of samples per leaf to three. The purpose of this task is to evaluate how the decision logic differs from that of the predefined decision trees. After training the new models, we analyzed the results; our conclusions follow.

First of all, we did not detect major structural differences between the new trees (dtcv1, dtcv2) and the original ones (dtc, dtc5). The plasma glucose concentration feature most frequently appears as the root node, or very close to it, in all models, which suggests it is the most important feature. Body mass index is also often selected in the upper levels of the trees, frequently acting as a secondary decision variable. Age and diabetes pedigree function are most frequently selected at intermediate levels. The remaining features (insulin, diastolic blood pressure, triceps skin fold thickness) tend to appear only in the deeper levels of the trees, or not at all, so we can conclude that they are the less important features.

The entropy splitting criterion did not lead to big changes in tree structure or split thresholds, nor did it change the importance or ordering of the features. The entropy-based tree selects similar variables at the top and produces comparable decision paths; regardless of the impurity criterion used, the models identify the same most important features.

Limiting the minimum number of samples per leaf produces simpler trees. However, even though it removes highly specific leaves and eliminates splits based on few observations, the structure of the tree looks almost unchanged, especially at the top levels. The main difference is that this tree does not have any node branching on the Number of times pregnant feature.

In conclusion, the trained trees do not show big differences in structure or logic: the prediction of a diabetes diagnosis looks mainly driven by plasma glucose concentration and body mass index; age and diabetes pedigree function are important as well, but with a lower strength in routing the model to a good prediction.
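The qualitative comparison above (e.g., glucose-like features sitting at or near the root regardless of the splitting criterion) can also be checked programmatically: `clf.tree_.feature` stores the feature index used at each internal node, with the root at index 0. A minimal, self-contained sketch on synthetic data (the variables here are illustrative, not from the notebook):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = (X[:, 0] > 0).astype(int)  # label driven by feature 0 only

roots = {}
for criterion in ("gini", "entropy"):
    clf = DecisionTreeClassifier(max_depth=5, criterion=criterion,
                                 random_state=0).fit(X, y)
    roots[criterion] = int(clf.tree_.feature[0])  # feature used at the root split
print(roots)  # both criteria should pick feature 0 at the root
```

The same kind of check (root feature, features used per depth level) can be run on dtc, dtc5, dtcv1 and dtcv2 to back the visual inspection with numbers.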

3.4. Using SHAP Values for interpretability¶

We can utilize SHAP values to tackle some of the shortcomings mentioned in the previous section:

  • SHAP is much more practical when dealing with, for example, a very deep tree with many nested conditions, as it offers a single importance value per feature. In this sense, it summarizes the explanatory information that is spread among the branches of the tree, although SHAP values do not necessarily offer the same interpretation you would extract from a tree's branches.
  • SHAP, as a post-hoc explainability method, can endow non-interpretable models (such as Random Forests) with interpretability.

SHAP still has a problem with semantic significance (i.e., features having meaning). Because it only gauges the importance of features, we remain fully dependent on those features (and their impact) offering understandable insight into the problem.

We'll see how to generate both global and local explanations with SHAP:

Global explanations¶

InĀ [15]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

# Building the explainers and computing Shapley values
# Unbound DT
explainerTreeC = shap.TreeExplainer(dtc)
shap_values_DTC = explainerTreeC(x_test)
# Limited-depth DT
explainerTree5C = shap.TreeExplainer(dtc5)
shap_values_DTC5 = explainerTree5C(x_test)
# Random Forest
explainerTreeRFC = shap.TreeExplainer(rfc)
shap_values_RFC = explainerTreeRFC(x_test)

explainers = [explainerTreeC, explainerTree5C, explainerTreeRFC]
InĀ [16]:
# A general commentary on the shap_values_{model} indexing: the SHAP library treats
# any classification as multi-class classification (you will see later that indexing
# is not needed for regressors), which means we need to specify the class we are
# explaining. In [a, b, c]: a selects the rows (instances in the dataset), b selects
# the columns (variables in the dataset), and c selects the class (0 for
# tested_negative, 1 for tested_positive).
# You may alter the plots to work with limited instance ranges, different column
# sets, or even for the opposite class.

shap.summary_plot(shap_values_DTC[:,:,1])
[SHAP summary plot]
InĀ [17]:
shap.summary_plot(shap_values_DTC5[:,:,1])
[SHAP summary plot]
InĀ [18]:
shap.summary_plot(shap_values_RFC[:,:,1])
[SHAP summary plot]

InĀ [19]:
## this code is developed by Pablo Miguel Perez-Ferreiro

# Equivalent for Regressors
# Unbound DT
explainerTreeR = shap.TreeExplainer(dtr)
shap_values_DTR = explainerTreeR(x_test)
# Limited-depth DT
explainerTree5R = shap.TreeExplainer(dtr5)
shap_values_DTR5 = explainerTree5R(x_test)
# Random Forest
explainerTreeRFR = shap.TreeExplainer(rfr)
shap_values_RFR = explainerTreeRFR(x_test)

explainers_r = [explainerTreeR, explainerTree5R, explainerTreeRFR]
InĀ [20]:
shap.summary_plot(shap_values_DTR)
[SHAP summary plot]
InĀ [21]:
shap.summary_plot(shap_values_DTR5)
[SHAP summary plot]
InĀ [22]:
shap.summary_plot(shap_values_RFR)
[SHAP summary plot]

Exercise I1.2¶

Now you have an alternative explanation tool against which you can contrast the insight you extracted in Exercise I1.1, as well as some knowledge of how the Random Forest is predicting. Try to interpret the graphs corresponding to the Classifier alternatives:

  • Do the SHAP values for the Decision Trees match your inspection of their graphical representation?
  • Do you find significant differences between the Decision Trees and the Random Forest?

We recommend reading the SHAP documentation in order to better understand how to interpret the generated plots. Keep in mind that the answer to any of these questions may be negative, but you should still try to theorise as to why (justify your responses).

Do the SHAP values for the Decision Trees match your inspection of their graphical representation?

Yes, the SHAP values confirm what we saw in the decision trees. Plasma glucose concentration has the highest impact on the predictions, just as it appears at the top of the trees' splits. Other features, such as body mass index, age, and diabetes pedigree function, also show significant influence, which matches their presence in the higher levels of the trees. The SHAP summary plot shows that high values (red points) of plasma glucose generally push predictions toward the positive class (dots on the right), consistent with the trees' decision rules.

Do you find significant differences between the Decision Trees and the Random Forest?

Yes, there are differences between the two types of models. The Random Forest exhibits less dispersion in its SHAP values, with points more densely grouped and fewer outliers. This reflects the ensemble nature of the Random Forest, where predictions are averaged over many trees, reducing variability and sensitivity to individual data points. In contrast, single Decision Trees tend to produce more scattered SHAP values and sharper transitions due to their reliance on hard decision thresholds.

Local explanations¶

Keep in mind that a local explanation may be justifying a wrong prediction. This can be misleading, so pay close attention to how the predicted class compares to the real one in the outputs below.

InĀ [23]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

for instance in [0, 24, 59, 76]: # specific cases, feel free to alter them
    print('----------------------------------------------------------------------------------------------------------------------------')
    print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_test.iloc[instance]}')
    for model, model_name, explainer in zip(models, models_names, explainers):
        print(f'The predicted output class by model {model_name} is: {model.predict(x_test.iloc[instance : instance+1])[0]}')
        shap_values = explainer(x_test.iloc[instance : instance+1])
        shap.summary_plot(shap_values[:,:,1])
        print('----------------------------------------------------------------------------------------------------------------------------')
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 0:
The real output class is: False
The predicted output class by model TREE is: True
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model TREE5 is: False
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF is: False
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 24:
The real output class is: False
The predicted output class by model TREE is: False
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model TREE5 is: False
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF is: False
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 59:
The real output class is: True
The predicted output class by model TREE is: False
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model TREE5 is: False
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF is: True
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 76:
The real output class is: True
The predicted output class by model TREE is: True
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model TREE5 is: True
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF is: True
[SHAP summary plot]
----------------------------------------------------------------------------------------------------------------------------

InĀ [24]:
## this code is developed by Pablo Miguel Perez-Ferreiro

# We will only execute one instance for regression to avoid cluttering the notebook.
for instance in [0]:
    print('----------------------------------------------------------------------------------------------------------------------------')
    print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_test_num[instance]}')
    for model, model_name, explainer in zip(models_r, models_r_names, explainers_r):
        print(f'The predicted output class by model {model_name} is: {model.predict(x_test.iloc[instance : instance+1])[0]}')
        shap_values = explainer(x_test.iloc[instance : instance+1])
        shap.summary_plot(shap_values)
        print('----------------------------------------------------------------------------------------------------------------------------')
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 0:
The real output class is: 0.0
The predicted output class by model TREE-R is: 1.0
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model TREE5-R is: 0.25
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF-R is: 0.381
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------

Exercise I1.3¶

Have a look at SHAP's API reference. So far, we have only used the basic summary plot, but one of the library's greatest strengths is its rich plotting environment. Try your hand at generating new plots (both global and local) following the code given during the section. Do you think they offer better information than the summaries we were using? Which plotting options do you think would be the best at explaining the problem at hand to a layperson? Would your answer change if the explanations were meant for an expert?

In [25]:
# You may need to execute the following for some of the plots to work
shap.initjs()
No description has been provided for this image
In [26]:
for instance in [24]: # specific cases, feel free to alter them
    print('----------------------------------------------------------------------------------------------------------------------------')
    print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_test.iloc[instance]}')
    for model, model_name, explainer in zip(models, models_names, explainers):
        print(f'The predicted output class by model {model_name} is: {model.predict(x_test.iloc[instance : instance+1])[0]}')
        shap_values = explainer(x_test.iloc[instance : instance+1])
        
        # select class 1
        shap_class = shap_values[:, :, 1]

        print("Bar:")
        shap.plots.bar(shap_class)
        print("Waterfall:")
        shap.plots.waterfall(shap_class[0])

        print("Partial dependence:")
        shap.plots.partial_dependence("Age", model.predict, x_test)
        print('----------------------------------------------------------------------------------------------------------------------------')
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 24:
The real output class is: False
The predicted output class by model TREE is: False
Bar:
No description has been provided for this image
Waterfall:
No description has been provided for this image
Partial dependence:
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model TREE5 is: False
Bar:
No description has been provided for this image
Waterfall:
No description has been provided for this image
Partial dependence:
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF is: False
Bar:
No description has been provided for this image
Waterfall:
No description has been provided for this image
Partial dependence:
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------

Explanation:

Yes, these plots do offer better and more complete information than the summary plot alone; however, they should be seen as complementary rather than as a substitute. The SHAP summary plot helps in understanding the global behavior of the model: it highlights which features are the most important and how their values influence the predictions. However, it is quite abstract and does not explain individual model decisions.

The bar and waterfall plots provide clearer local explanations. The bar plot makes it easy to see which features contribute the most to a single prediction, while the waterfall plot, starting from a baseline, explicitly shows how each feature pushes the model's output toward or away from the predicted class. This makes the decision process more interpretable at the instance level and helps us understand why a specific patient is classified as diabetic or not. Partial dependence plots add another perspective by showing how a feature affects predictions on average across the dataset, which can be useful for understanding general trends.

For a layperson, the bar and especially the waterfall plots are the most effective. They are intuitive, visually clear, and make it easy to explain which factors increase the risk and which decrease it for a given individual. In contrast, the summary plot and partial dependence plots can be harder to interpret without a technical background.

If the explanations were meant for an expert, the answer would change. An expert audience would benefit more from the summary plot and partial dependence plots, as these provide global insights into model behavior, feature interactions, and overall consistency. In that case, local plots would still be useful, but mainly as a complement for analyzing specific cases or debugging the model rather than as the primary explanation tool.

3.5. Trade-offs¶

We have now studied two ways of making ML models more interpretable: either through direct inspection when the model allows it (because it is interpretable by design), or by applying SHAP as a post-hoc method. It is important to note, however, that interpretability, while very important, is not to be gauged in a vacuum. A very interpretable model that performs poorly is useless: it may explain its reasoning, but the reasoning itself is flawed and there is no point in understanding it.

In a less extreme situation, we may find ourselves faced with a decision between a highly interpretable model that performs reasonably well, and a less interpretable model that performs even better. Whatever we choose, we have a trade-off, and a good way to decide is through the construction of a Pareto front that allows us to objectively compare the relative virtues of all options, presenting a performance metric against an interpretability metric:
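The dominance relation behind a Pareto front can be sketched in a few lines. The helper below is an illustrative stand-in (not the notebook's own plot_pareto_front), assuming we want to maximize accuracy and minimize a complexity count such as the number of leaves:

```python
def pareto_front(points):
    """Indices of non-dominated (accuracy, complexity) pairs:
    higher accuracy is better, lower complexity is better."""
    front = []
    for i, (acc_i, comp_i) in enumerate(points):
        # a point is dominated if some other point is at least as good on
        # both axes and strictly better on at least one
        dominated = any(
            acc_j >= acc_i and comp_j <= comp_i and (acc_j > acc_i or comp_j < comp_i)
            for j, (acc_j, comp_j) in enumerate(points) if j != i
        )
        if not dominated:
            front.append(i)
    return front

# (accuracy, leaf count) for TREE, TREE5 and RF as reported in this section
print(pareto_front([(0.688, 132), (0.727, 27), (0.805, 198)]))  # [1, 2]: TREE is dominated by TREE5
```

With these numbers, TREE5 dominates TREE (higher accuracy with fewer leaves), so only TREE5 and RF lie on the front.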

In [27]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

# Building and visualizing the Pareto front with predictive accuracy versus number of leaf nodes, which can be understood as a surrogate for interpretability (remember
# how the largest tree was much harder to visually inspect)

limits = [min([dtc.get_n_leaves(), dtc5.get_n_leaves()])*0.75, max([dtc.get_n_leaves(), dtc5.get_n_leaves()])*1.5]
# Random Forest is a black box model and there is no easy, fair way to give it a nominal leaf node amount. We settle here for multiplying the largest tree x1.5 (penalty for opaqueness).
x_axis = [dtc.get_n_leaves(), dtc5.get_n_leaves(), limits[1]]

print("The data for the Pareto front is as follows:")
for name, accuracy, leaves in zip(models_names, models_acc, x_axis):
    print(f'\t-For model {name}, accuracy is {accuracy} with a total of {leaves} leaf nodes.')
    
plt.figure(figsize=[15,10])
plot_pareto_front(x_axis, models_acc, models_names, 'Accuracy (Classification Ratio)','Interpretability (Num of rules / leaves)', limits[0], limits[1])
The data for the Pareto front is as follows:
	-For model TREE, accuracy is 0.688 with a total of 132 leaf nodes.
	-For model TREE5, accuracy is 0.727 with a total of 27 leaf nodes.
	-For model RF, accuracy is 0.805 with a total of 198.0 leaf nodes.
No description has been provided for this image
In [28]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

# Building and visualizing the Pareto front with Accuracy versus Shap Length

x_axis = [get_shap_explanation_length(explainer.shap_values(x_test)[:,:,1]) for explainer in explainers]
limits = [min(x_axis)*0.99, max(x_axis)*1.01]

print("The data for the Pareto front is as follows:")
for name, accuracy, shap_l in zip(models_names, models_acc, x_axis):
    print(f'\t-For model {name}, accuracy is {accuracy} with a SHAP length of {shap_l}.')

plt.figure(figsize=[15,10])
plot_pareto_front(x_axis, models_acc, models_names, 'Accuracy (Classification Ratio)', 'Interpretability (Shap Length)', limits[0], limits[1])
The data for the Pareto front is as follows:
	-For model TREE, accuracy is 0.688 with a SHAP length of 558.
	-For model TREE5, accuracy is 0.727 with a SHAP length of 561.
	-For model RF, accuracy is 0.805 with a SHAP length of 557.
No description has been provided for this image

4. A more integral approach: InterpretML¶

In the previous section, we followed the full process of building an intelligent system and trying to understand how it works. However, thanks to the growing interest in AI trustworthiness and explainability (which motivates this subject!), we now have many tools at our disposal that make the life of a trustworthy AI engineer much easier. The InterpretML Python library is intended to do just that; play around with it to round off this first practical session.

In [29]:
## this code is developed by Jose Maria Alonso-Moral

# This is InterpretML's own approach: a 'glass-box' model that aims to provide opaque-model performance with full explainability.
# As such, it is already designed with explanations in mind, and supports them natively.
ebm = ExplainableBoostingClassifier()
ebm.fit(x_tr, y_tr)

# This will provide global insight on the model. You can change tabs to study the effect of individual variables, and also of interactions between pairs of them.
ebm_global = ebm.explain_global()
show(ebm_global)
In [30]:
## this code is developed by Jose Maria Alonso-Moral

# This code explains local instances
instance = 0
ebm_local = ebm.explain_local(x_test.iloc[instance : instance+1], y_test.iloc[instance : instance+1])
show(ebm_local)

Exercise I1.4¶

Now that you've reached the end of this first practical session, you have all the tools needed to autonomously perform the following task. Use the code below to import the SONGS dataset, present on your workspace, which contains 2017 data instances representing songs from a Kaggle competition. The dataset is intended to serve for the classification problem of predicting if the song will be 'Like' or 'Dislike'.

With the dataset imported, follow the process laid out throughout this notebook to explain the global behaviour of the model. Then, find a song you like (and another song that you don't like) in the dataset (using the code provided), and locally explain the predictions made by a decision tree with a good interpretability-accuracy trade-off and by a Random Forest. Discuss all your findings.

In [51]:
## this code is developed by Jose Maria Alonso-Moral and Pablo Miguel Perez-Ferreiro

# importing the dataset
file_songs = 'testlib/SONGS/SONGS.arff'
file_songs_data = 'testlib/SONGS/spotifyData.csv'
with open(file_songs) as f:
    songs_data_value, songs_attributes = loadarff(f)
    f.close()
with open(file_songs_data) as f:
    song_names = pd.read_csv(file_songs_data).loc[:, ['artist', 'song_title', 'id']]
    f.close()
songs_fnames=["acoustics","danceability","duration","energy","instrumentalness","key","liveness","loudness","mode","speechiness","tempo","timeSignature","valence"]
songs_class_names = ['Dislike','Like']

# You don't need to perform a train/test split or validation of the models later. 
songs_class_names = np.array(songs_class_names)
songs_feature_names = np.array(songs_attributes.names())
df_songs=pd.DataFrame(songs_data_value)
df_songs.columns = songs_feature_names
songs_target = df_songs.pop('class')
songs_target_onehot = pd.get_dummies(songs_target)[b'2'] 
X_songs = df_songs
y_songs = songs_target_onehot
# You don't need the numerical variant, as you don't need to do the Regression equivalents.

print(f'Class names for the SONGS Dataset:\n\t{", ".join(list(songs_class_names))}\nFeature names for the SONGS Dataset:\n\t{", ".join(list(songs_feature_names))}')
Class names for the SONGS Dataset:
	Dislike, Like
Feature names for the SONGS Dataset:
	acousticness, danceability, duration_ms, energy, instrumentalness, key, liveness, loudness, mode, speechiness, tempo, time_signature, valence, class
In [52]:
# You can use this code to search for text in either the artist or the title of a song and record its ID, so that you can look it up in the prediction dataset later.
# If you can't seem to find an artist or song, it may not be in the dataset (which is not that exhaustive), or you may be spelling it differently from the dataset's
# record, in which case you can try an alternative spelling (such as 'michael' instead of 'Michael'). A trick you can use, because str.contains() treats its input as a
# regular expression, is to write just the middle part: instead of 'Billie', just write 'illie' to avoid upper/lowercase conflicts.

artist_songs = song_names[song_names['artist'].str.contains("Michael Jackson")]
named_songs = song_names[song_names['song_title'].str.contains("Hips Don")]
print(artist_songs)
print(named_songs)
               artist                       song_title    id
1827  Michael Jackson                      Billie Jean  1827
1828  Michael Jackson         Beat It - Single Version  1828
1829  Michael Jackson  Black or White - Single Version  1829
1830  Michael Jackson         The Way You Make Me Feel  1830
1831  Michael Jackson                Man In The Mirror  1831
1832  Michael Jackson      P.Y.T. (Pretty Young Thing)  1832
1920  Michael Jackson                Remember the Time  1920
1948  Michael Jackson  Earth Song - Remastered Version  1948
       artist      song_title    id
1927  Shakira  Hips Don't Lie  1927
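As a side note on the matching trick above: pandas' str.contains() also accepts a case=False argument, which avoids the upper/lowercase issue directly. A minimal sketch on a toy stand-in for the song_names table (the two rows shown are taken from the outputs above):

```python
import pandas as pd

# toy stand-in for the song_names table loaded earlier
song_names_demo = pd.DataFrame({
    'artist': ['Michael Jackson', 'Shakira'],
    'song_title': ['Billie Jean', "Hips Don't Lie"],
    'id': [1827, 1927],
})

# case-insensitive match: no need to guess the dataset's capitalization
hits = song_names_demo[song_names_demo['artist'].str.contains('michael jackson', case=False)]
print(hits['id'].tolist())  # [1827]
```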
In [53]:
# 1600 instances for training, approximately 80% of 2017
X_songs_train = X_songs.iloc[0:1600, :]
X_songs_test  = X_songs.iloc[1600:, :]

y_songs_train = y_songs.iloc[0:1600]
y_songs_test  = y_songs.iloc[1600:]
In [54]:
print(X_songs.iloc[1927])
print(X_songs_test.iloc[1927-1600])
acousticness             0.2840
danceability             0.7780
duration_ms         218093.0000
energy                   0.8240
instrumentalness         0.0000
key                     10.0000
liveness                 0.4050
loudness                -5.8920
mode                     0.0000
speechiness              0.0712
tempo                  100.0240
time_signature           4.0000
valence                  0.7670
Name: 1927, dtype: float64
acousticness             0.2840
danceability             0.7780
duration_ms         218093.0000
energy                   0.8240
instrumentalness         0.0000
key                     10.0000
liveness                 0.4050
loudness                -5.8920
mode                     0.0000
speechiness              0.0712
tempo                  100.0240
time_signature           4.0000
valence                  0.7670
Name: 1927, dtype: float64
In [56]:
# We fit a Decision Tree, and a Random Forest.
dtc_songs = tree.DecisionTreeClassifier(max_depth=5)
dtc_songs.fit(X_songs_train, y_songs_train)

rfc_songs = RandomForestClassifier(n_estimators=1000)
rfc_songs.fit(X_songs_train, y_songs_train)

models_songs = [dtc_songs, rfc_songs]
models_names_songs = ['TREE_songs', 'RF_songs']

# Cross-validation for the models just fit above 
print("1) Cross-validation (over training data)")
scorings = ['accuracy', 'f1']  # For binary classification
nF= 5
for model, model_name in zip(models_songs, models_names_songs):
    cv_results = cross_validate(model, X_songs_train, y_songs_train, cv=nF, 
                            scoring=scorings,
                            return_train_score=False)
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate [Average (St. Dev)] = {np.mean(cv_results["test_accuracy"]):.3f} ({np.std(cv_results["test_accuracy"]):.3f})')
    print(f'\t\tF-Score [Average (St. Dev)] = {np.mean(cv_results["test_f1"]):.3f} ({np.std(cv_results["test_f1"]):.3f})')

# Test with unknown instances
print("\n\n2) Test (with previously unseen data)")
models_acc_songs=[]
target_names = ['class 1', 'class 2']
for model, model_name in zip(models_songs, models_names_songs):
    sc = round(model.score(X_songs_test, y_songs_test), 3) # we round the accuracy to 3 decimals and append it to the list of model accuracies for the test set
    models_acc_songs.append(sc)
    # Mean accuracy of self.predict(x_test) wrt y_test
    print(f'\n\t{model_name}:')
    print(f'\t\tCorrect Classification Rate: {models_acc_songs[-1]:.3f}')
    y_pred= model.predict(X_songs_test)
    print(classification_report(y_songs_test, model.predict(X_songs_test), target_names=target_names))
1) Cross-validation (over training data)

	TREE_songs:
		Correct Classification Rate [Average (St. Dev)] = 0.676 (0.041)
		F-Score [Average (St. Dev)] = 0.761 (0.029)

	RF_songs:
		Correct Classification Rate [Average (St. Dev)] = 0.727 (0.054)
		F-Score [Average (St. Dev)] = 0.811 (0.030)


2) Test (with previously unseen data)

	TREE_songs:
		Correct Classification Rate: 0.542
              precision    recall  f1-score   support

     class 1       1.00      0.54      0.70       417
     class 2       0.00      0.00      0.00         0

    accuracy                           0.54       417
   macro avg       0.50      0.27      0.35       417
weighted avg       1.00      0.54      0.70       417


	RF_songs:
		Correct Classification Rate: 0.458
              precision    recall  f1-score   support

     class 1       1.00      0.46      0.63       417
     class 2       0.00      0.00      0.00         0

    accuracy                           0.46       417
   macro avg       0.50      0.23      0.31       417
weighted avg       1.00      0.46      0.63       417

In [57]:
print(export_text(dtc_songs, feature_names=songs_feature_names[0:-1]))
dot_data_tree_songs = tree.export_graphviz(dtc_songs, out_file=None, filled=True, rounded=True, special_characters=True, feature_names=songs_feature_names[0:-1]) 
graph_tree_songs = graphviz.Source(dot_data_tree_songs) 
graph_tree_songs#.render(format='png') # Big trees are not very easy to examine on the notebook, so you may uncomment the render call to save them to a PNG file so that you can check out the graph elsewhere.
|--- energy <= 0.20
|   |--- tempo <= 182.99
|   |   |--- danceability <= 0.15
|   |   |   |--- class: True
|   |   |--- danceability >  0.15
|   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |--- speechiness <= 0.05
|   |   |   |   |   |--- class: True
|   |   |   |   |--- speechiness >  0.05
|   |   |   |   |   |--- class: False
|   |   |   |--- instrumentalness >  0.00
|   |   |   |   |--- duration_ms <= 173942.50
|   |   |   |   |   |--- class: False
|   |   |   |   |--- duration_ms >  173942.50
|   |   |   |   |   |--- class: False
|   |--- tempo >  182.99
|   |   |--- class: True
|--- energy >  0.20
|   |--- loudness <= -5.63
|   |   |--- instrumentalness <= 0.00
|   |   |   |--- speechiness <= 0.08
|   |   |   |   |--- danceability <= 0.78
|   |   |   |   |   |--- class: False
|   |   |   |   |--- danceability >  0.78
|   |   |   |   |   |--- class: True
|   |   |   |--- speechiness >  0.08
|   |   |   |   |--- energy <= 0.34
|   |   |   |   |   |--- class: False
|   |   |   |   |--- energy >  0.34
|   |   |   |   |   |--- class: True
|   |   |--- instrumentalness >  0.00
|   |   |   |--- acousticness <= 0.86
|   |   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- instrumentalness >  0.00
|   |   |   |   |   |--- class: True
|   |   |   |--- acousticness >  0.86
|   |   |   |   |--- energy <= 0.39
|   |   |   |   |   |--- class: False
|   |   |   |   |--- energy >  0.39
|   |   |   |   |   |--- class: True
|   |--- loudness >  -5.63
|   |   |--- duration_ms <= 259198.00
|   |   |   |--- duration_ms <= 165915.00
|   |   |   |   |--- valence <= 0.18
|   |   |   |   |   |--- class: False
|   |   |   |   |--- valence >  0.18
|   |   |   |   |   |--- class: True
|   |   |   |--- duration_ms >  165915.00
|   |   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |   |--- class: False
|   |   |   |   |--- instrumentalness >  0.00
|   |   |   |   |   |--- class: True
|   |   |--- duration_ms >  259198.00
|   |   |   |--- acousticness <= 0.00
|   |   |   |   |--- class: False
|   |   |   |--- acousticness >  0.00
|   |   |   |   |--- instrumentalness <= 0.00
|   |   |   |   |   |--- class: True
|   |   |   |   |--- instrumentalness >  0.00
|   |   |   |   |   |--- class: True

Out[57]:
No description has been provided for this image
In [58]:
explainerTreeSongs = shap.TreeExplainer(dtc_songs)
shap_values_DTC_songs = explainerTreeSongs(X_songs_test)

explainerRFSongs = shap.TreeExplainer(rfc_songs)
shap_values_RFC_songs = explainerRFSongs(X_songs_test)

explainers_songs = [explainerTreeSongs, explainerRFSongs]
In [59]:
shap.summary_plot(shap_values_DTC_songs[:,:,1])
No description has been provided for this image
In [60]:
shap.summary_plot(shap_values_RFC_songs[:,:,1])
No description has been provided for this image

Global explanation:

In analyzing the decision tree, we observe that the feature energy is the root of the tree and drives the first split, so we can conclude that it is the most important feature for the model. This suggests that the model distinguishes songs depending on how energetic they are.

After energy, the tempo and loudness features also play an important role in the model's decision logic, while features such as danceability and duration are used to refine the decisions at deeper levels of the tree; these last features contribute to the final classification but are less important than energy, tempo, and loudness.

The SHAP summary plot provides a complementary point of view. According to the SHAP plot, instrumentalness and loudness are the most important features. In particular, for loudness, it is clear that higher values tend to push the prediction toward the "Dislike" class: very loud songs are generally less preferred by the model.

Moreover, the SHAP plot shows some outliers related to energy, where very low energy values strongly contribute to a "Dislike" prediction.

This shows that, even if energy is an important feature in the tree’s decisions, very high or very low energy values can strongly affect the model’s predictions.

When comparing the SHAP summary plots of the Decision Tree and the Random Forest, we can notice small differences. Since the Random Forest combines the predictions of many trees, it produces a smoother distribution of SHAP values, which reduces variability and the influence of individual splits.

In contrast, the Decision Tree assigns zero or near-zero importance to some features, since a single tree may never use them in its splits. This explains why certain features appear unimportant in the Decision Tree SHAP plot but still have importance in the Random Forest.
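The zero or near-zero importances just mentioned can be checked numerically: the global ranking shown in SHAP's bar and summary plots corresponds to the mean absolute SHAP value of each feature. A minimal sketch with made-up numbers (not actual SHAP output from these models):

```python
import numpy as np

# hypothetical (n_samples, n_features) SHAP value matrix;
# the third feature is never used in any split, so all its values are 0
shap_vals = np.array([
    [ 0.20, -0.10, 0.0],
    [-0.30,  0.05, 0.0],
    [ 0.25, -0.15, 0.0],
])

# global importance = mean absolute contribution per feature
importance = np.abs(shap_vals).mean(axis=0)
# importance is approximately [0.25, 0.10, 0.0]: the unused feature gets exactly 0
```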

In the end, both models are able to identify the same factors that influence musical preferences: the Random Forest provides a more reliable and detailed explanation, while the Decision Tree provides simpler and easier-to-understand decision rules.

In [61]:
for instance in [1828-1600, 1927-1600]: # specific cases, feel free to alter them
    print('----------------------------------------------------------------------------------------------------------------------------')
    print(f'\nWORKING WITH INSTANCE {instance}:\nThe real output class is: {y_songs_test.iloc[instance]}')
    for model, model_name, explainer in zip(models_songs, models_names_songs, explainers_songs):
        print(f'The predicted output class by model {model_name} is: {model.predict(X_songs_test.iloc[instance : instance+1])[0]}')
        shap_values = explainer(X_songs_test.iloc[instance : instance+1])
        
        shap_class = shap_values[:, :, 1]

        print("Summary:")
        shap.summary_plot(shap_class)
        
        print("Bar:")
        shap.plots.bar(shap_class)
        print("Waterfall:")
        shap.plots.waterfall(shap_class[0])

        print("Partial dependence:")
        shap.plots.partial_dependence("loudness", model.predict, X_songs_test)
        print('----------------------------------------------------------------------------------------------------------------------------')
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 228:
The real output class is: False
The predicted output class by model TREE_songs is: False
Summary:
No description has been provided for this image
Bar:
No description has been provided for this image
Waterfall:
No description has been provided for this image
Partial dependence:
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF_songs is: True
Summary:
No description has been provided for this image
Bar:
No description has been provided for this image
Waterfall:
No description has been provided for this image
Partial dependence:
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------

WORKING WITH INSTANCE 327:
The real output class is: False
The predicted output class by model TREE_songs is: True
Summary:
No description has been provided for this image
Bar:
No description has been provided for this image
Waterfall:
No description has been provided for this image
Partial dependence:
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------
The predicted output class by model RF_songs is: False
Summary:
No description has been provided for this image
Bar:
No description has been provided for this image
Waterfall:
No description has been provided for this image
Partial dependence:
No description has been provided for this image
----------------------------------------------------------------------------------------------------------------------------

Local explanation:

The selected songs for local analysis are Beat It by Michael Jackson (ID 1828 in the original dataset, corresponding to index 228 in the test set) and Hips Don’t Lie by Shakira (ID 1927, corresponding to index 327 in the test set). Both songs were analyzed using the Decision Tree and the Random Forest in order to compare their local explanations.

For Beat It, the local SHAP explanations show that loudness and instrumentalness are the most influential features, and both contribute negatively to the probability of the song being liked. This suggests that, for this specific instance, high loudness levels and low instrumental content push the prediction toward the "Dislike" class. On the other hand, danceability has a positive contribution, partially compensating for the negative effect of loudness and instrumentalness and supporting the "Like" prediction. This combination of features reflects a trade-off between rhythmic appeal and production characteristics in the model's decision.

In the case of Hips Don’t Lie, the same three features—loudness, danceability, and instrumentalness—also appear as the most important. However, their effects differ depending on the model. In the Decision Tree, loudness contributes positively to the prediction, indicating that higher loudness levels increase the likelihood of the song being liked for this instance. When analyzing the Random Forest, the importance of loudness and danceability decreases, while instrumentalness becomes the dominant factor by a large margin. This change reflects the more stable and averaged behavior of the Random Forest, which smooths the influence of individual features and relies more heavily on consistent patterns across many trees.

Overall, the local explanations highlight how the same features can have different impacts depending on both the specific song and the model used, emphasizing the value of local interpretability tools such as SHAP.

In [62]:
# Building and visualizing the Pareto front with predictive accuracy versus number of leaf nodes, which can be understood as a surrogate for interpretability (remember
# how the largest tree was much harder to visually inspect)

limits = [min([dtc_songs.get_n_leaves()])*0.75, max([dtc_songs.get_n_leaves()])*1.5]
# Random Forest is a black box model and there is no easy, fair way to give it a nominal leaf node amount. We settle here for multiplying the largest tree x1.5 (penalty for opaqueness).
x_axis_songs = [dtc_songs.get_n_leaves(), limits[1]]

print("The data for the Pareto front is as follows:")
for name, accuracy, leaves in zip(models_names_songs, models_acc_songs, x_axis_songs):
    print(f'\t-For model {name}, accuracy is {accuracy} with a total of {leaves} leaf nodes.')
    
plt.figure(figsize=[15,10])
plot_pareto_front(x_axis_songs, models_acc_songs, models_names_songs, 'Accuracy (Classification Ratio)','Interpretability (Num of rules / leaves)', limits[0], limits[1])
The data for the Pareto front is as follows:
	-For model TREE_songs, accuracy is 0.542 with a total of 21 leaf nodes.
	-For model RF_songs, accuracy is 0.458 with a total of 31.5 leaf nodes.
No description has been provided for this image
In [63]:
# Building and visualizing the Pareto front with Accuracy versus Shap Length

x_axis_songs = [get_shap_explanation_length(explainer.shap_values(X_songs_test)[:,:,1]) for explainer in explainers_songs]
limits = [min(x_axis_songs)*0.99, max(x_axis_songs)*1.01]

print("The data for the Pareto front is as follows:")
for name, accuracy, shap_l in zip(models_names_songs, models_acc_songs, x_axis_songs):
    print(f'\t-For model {name}, accuracy is {accuracy} with a SHAP length of {shap_l}.')

plt.figure(figsize=[15,10])
plot_pareto_front(x_axis_songs, models_acc_songs, models_names_songs, 'Accuracy (Classification Ratio)', 'Interpretability (Shap Length)', limits[0], limits[1])
The data for the Pareto front is as follows:
	-For model TREE_songs, accuracy is 0.542 with a SHAP length of 4879.
	-For model RF_songs, accuracy is 0.458 with a SHAP length of 4866.
No description has been provided for this image